import torch
from torch import autograd, nn
import torch.nn.functional as F

from itertools import repeat

from .legacy import sparse24

# Major and minor components of the installed PyTorch version string
# (e.g. "1.7.1" -> 1, 7).
TORCH_MAJOR, TORCH_MINOR = (int(part) for part in torch.__version__.split('.')[:2])
if TORCH_MAJOR != 1 or TORCH_MINOR >= 8:
    import collections.abc as container_abcs
else:
    # torch 1.x before 1.8: take container_abcs from torch._six instead.
    from torch._six import container_abcs


class SparseTranspose(autograd.Function):
    """Prune the weight for the forward pass; pass the gradient to the dense weight.

    Forward returns ``(weight.sparse * weight.scale, weight.mask)``:

    * While ``weight.cnt == 0`` the pruned copy ``weight.sparse`` and its mask are
      recomputed with ``sparse24``. The previous mask is first copied into
      ``weight.old_mask``. Otherwise the cached pruned copy and mask are reused.
      ``weight.cnt`` is incremented on every call.
    * ``weight.scale`` is computed once, the first time it is missing. It is the
      least-squares factor ``<w, s> / <s, s>`` that best maps the pruned weight
      ``s`` back onto the dense weight ``w``. It falls back to 1 if ``s`` is all
      zeros.

    Backward is a plain straight-through estimator. The incoming gradient is handed
    to the dense weight unchanged; no SR-STE decay term is added here.

    The ``weight`` tensor must already carry the attributes ``mask``, ``old_mask``,
    ``sparse`` and ``cnt``. ``N``, ``M``, ``counter``, the unnamed fifth argument
    and ``absorb_mean`` are accepted but not used.
    """

    @staticmethod
    def forward(ctx, weight, N, M, counter, _, absorb_mean):
        # Keep the cached state on the weight's device/dtype. The weight may have
        # been moved (e.g. by Module.to) since the state was created.
        weight.mask = weight.mask.to(weight)
        weight.old_mask = weight.old_mask.to(weight)
        weight.sparse = weight.sparse.to(weight)

        if weight.cnt == 0:
            weight_temp = weight.detach()
            weight.sparse, weight_mask = sparse24(weight_temp)
            if getattr(weight, 'scale', None) is None:
                dense_flat = torch.flatten(weight_temp)
                sparse_flat = torch.flatten(weight.sparse)
                numer = torch.dot(dense_flat, sparse_flat)
                denom = torch.dot(sparse_flat, sparse_flat)
                # An all-zero pruned weight would give 0/0 = NaN. The scale is
                # never recomputed afterwards, so fall back to 1 in that case.
                weight.scale = torch.where(denom != 0, numer / denom, torch.ones_like(denom))
            weight.old_mask.data.copy_(weight.mask.data)
            # Store the new mask in the weight's dtype/device. This matches what
            # the conversion at the top of this method yields on later calls.
            weight.mask = weight_mask.to(weight)
        weight.cnt += 1
        # The mask is a bookkeeping output only. Keep it out of the autograd
        # graph so the cached weight.mask tensor never acquires a grad_fn.
        ctx.mark_non_differentiable(weight.mask)
        return weight.sparse * weight.scale, weight.mask

    @staticmethod
    def backward(ctx, grad_output, _):
        # Straight-through: the dense weight receives the output gradient as-is.
        # The remaining forward arguments (N, M, counter, _, absorb_mean) get no gradient.
        return grad_output, None, None, None, None, None


class SparseLinear(nn.Linear):
    """``nn.Linear`` that multiplies inputs by a pruned (sparse) copy of its weight.

    Training mode obtains the pruned, rescaled weight from ``SparseTranspose`` and
    stores the returned mask on ``self.weight.mask``. Eval mode reuses the most
    recently cached ``weight.sparse`` and ``weight.scale``.

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear``.
        N, M: stored as ``self.N`` / ``self.M`` and passed to ``SparseTranspose``.
        decay: accepted but not used by this class.
        **kwargs: accepted and ignored.

    NOTE(review): eval mode needs at least one training-mode forward first, or
    externally assigned ``weight.sparse`` / ``weight.scale``. Otherwise
    ``weight.scale`` does not exist and ``weight.sparse`` is uninitialized memory.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, N=2, M=4, decay=0.0002, **kwargs):
        self.N = N
        self.M = M
        super(SparseLinear, self).__init__(in_features, out_features, bias=bias)
        # State consumed by SparseTranspose, kept as plain attributes on the
        # Parameter (not registered buffers).
        self.weight.counter = 0  # incremented each time a mask refresh is due
        self.weight.sparse = torch.empty_like(self.weight)
        self.weight.cnt = 0  # 0 => the next training forward recomputes sparse/mask
        # SparseTranspose reads weight.mask and weight.old_mask before it ever
        # assigns them. Start both as an all-ones (dense) mask so the first
        # training forward does not fail with AttributeError.
        self.weight.mask = torch.ones_like(self.weight)
        self.weight.old_mask = torch.ones_like(self.weight)

    def get_sparse_weights(self):
        """Return ``(pruned_and_scaled_weight, mask)`` computed by ``SparseTranspose``."""
        return SparseTranspose.apply(self.weight, self.N, self.M, self.weight.counter, None, False)

    def forward(self, x):
        if self.training:
            if self.weight.cnt == 0:
                # SparseTranspose recomputes the mask on this call.
                self.weight.counter += 1
            w, mask = self.get_sparse_weights()
            self.weight.mask = mask
        else:
            # .to() keeps the cached pruned copy on the weight's device/dtype in
            # case the module was moved after the last training-mode forward.
            w = self.weight.sparse.to(self.weight) * self.weight.scale
        return F.linear(x, w, self.bias)
